# -- coding: utf-8 --
"""
Copyright (c) 2018. All rights reserved.
Created by C. L. Wang on 2018/4/18
"""
import os

from base.base_eval import BaseEval
from keras.models import load_model


class SimpleMnistEval(BaseEval):
    """Evaluator that loads a saved Keras model from disk and predicts with it.

    The model file is resolved relative to ``config.cp_dir`` (presumably the
    checkpoint directory -- confirm against the config schema) and is loaded
    once, at construction time.
    """

    def __init__(self, name, config=None):
        """Initialise the evaluator and load the saved model ``name``.

        :param name: file name of the saved model, relative to ``config.cp_dir``.
        :param config: configuration object forwarded to ``BaseEval``; it must
            expose a ``cp_dir`` attribute by the time the model is loaded.
        :raises ValueError: if no configuration is available after
            ``BaseEval.__init__`` runs.
        """
        super(SimpleMnistEval, self).__init__(config)
        self.model = self.load_model(name)

    def load_model(self, name):
        """Load and return the Keras model stored at ``config.cp_dir/name``.

        :param name: file name of the saved model inside ``config.cp_dir``.
        :return: the model object returned by ``keras.models.load_model``.
        :raises ValueError: if ``self.config`` is missing or ``None``.
        """
        # ``config`` defaults to None in __init__. Without this guard, that
        # surfaced as an opaque "'NoneType' object has no attribute 'cp_dir'".
        # Checking after BaseEval.__init__ lets a base class substitute a
        # default config if it chooses to.
        if getattr(self, "config", None) is None:
            raise ValueError(
                "SimpleMnistEval requires a config with a 'cp_dir' attribute "
                "to locate model %r" % (name,))
        model_path = os.path.join(self.config.cp_dir, name)
        # This bare name is the module-level ``keras.models.load_model`` import,
        # not this method: a method's own name is not in scope inside its body.
        return load_model(model_path)

    def predict(self, data):
        """Run the loaded model on ``data`` and return its predictions.

        :param data: model input, passed straight to ``Model.predict``.
            NOTE(review): presumably a batch of MNIST images shaped as the
            model was trained on -- confirm with the caller.
        :return: whatever ``self.model.predict`` returns (per the Keras docs,
            NumPy array(s) of predictions).
        """
        return self.model.predict(data)
